import numpy as np

CLASS_DICT = {
    "其他": 0,
    "玉米": 1,
    "大豆": 2,
    "水稻": 3
}

DICT2COLOR = {
    0: (0, 0, 0),
    1: (255, 0, 0),
    2: (0, 255, 0),
    3: (0, 0, 255)
}
# 其他区域、玉米、大豆和水稻像素值分别为0,20,40,60
DICT2CLASSKEY = {
    0: 0,
    1: 20,
    2: 40,
    3: 60
}


def image2label(im):
    # 输入为标记图像的矩阵，输出为单通道映射的label图像
    data = im.astype('int32')
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return np.array(COLOR_LABEL_MAP[idx])


def color2labelmp():
    cm2lbl = np.zeros(256 ** 3)
    for cm in DICT2COLOR.items():
        id = cm[0]
        cm = cm[1]
        cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = DICT2CLASSKEY.get(id)
    return cm2lbl


COLOR_LABEL_MAP = color2labelmp()
